{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# REINFORCE in pytorch\n",
    "\n",
    "Just like we did before for q-learning, this time we'll design a lasagne network to learn `CartPole-v0` via policy gradient (REINFORCE).\n",
    "\n",
    "Most of the code in this notebook is taken from approximate qlearning, so you'll find it more or less familiar and even simpler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: THEANO_FLAGS='floatX=float32'\n"
     ]
    }
   ],
   "source": [
    "%env THEANO_FLAGS='floatX=float32'\n",
    "import os\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\"))==0:\n",
    "    !bash ../xvfb start\n",
    "    %env DISPLAY=:1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f247570f5f8>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD8CAYAAAB9y7/cAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEnBJREFUeJzt3XGs3eV93/H3ZzaBLMlqCBfk2WYmrbeGTouhd8QR00Qh\nbYFVNZWaCTY1KEK6TCJSokZboZPWRBpSK61hi7ahuIXGqbIQRpJhIdrUc4iq/BGISRxicChOYoVb\ne/hmAZIsGhvkuz/Oc8OZfXzv8b33+Po+eb+ko/P7Pb/n/M73wYfP/fm5v8cnVYUkqT9/Y7ULkCRN\nhgEvSZ0y4CWpUwa8JHXKgJekThnwktSpiQV8kuuSPJPkcJI7JvU+kqTRMon74JOsA/4K+GVgFvgy\ncHNVPb3ibyZJGmlSV/BXAoer6ltV9X+A+4GdE3ovSdII6yd03k3Ac0P7s8DbT9X5wgsvrK1bt06o\nFElae44cOcJ3v/vdLOcckwr4UUX9f3NBSWaAGYBLLrmE/fv3T6gUSVp7pqenl32OSU3RzAJbhvY3\nA0eHO1TVrqqarqrpqampCZUhST+9JhXwXwa2Jbk0yeuAm4A9E3ovSdIIE5miqapXkrwX+BywDriv\nqp6axHtJkkab1Bw8VfUI8Mikzi9JWpgrWSWpUwa8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ6pQB\nL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdWpZX9mX5Ajw\nA+BV4JWqmk5yAfApYCtwBPinVfXC8sqUJJ2ulbiC/6Wq2l5V023/DmBfVW0D9rV9SdIZNokpmp3A\n7ra9G7hxAu8hSVrEcgO+gL9I8kSSmdZ2cVUdA2jPFy3zPSRJS7CsOXjgqqo6muQiYG+Sb4z7wvYD\nYQbgkksuWWYZkqQTLesKvqqOtufjwGeBK4Hnk2wEaM/HT/HaXVU1XVXTU1NTyylDkjTCkgM+yRuS\nvGl+G/gV4CCwB7ildbsFeGi5RUqSTt9ypmguBj6bZP48/6Wq/jzJl4EHktwKfAd41/LLlCSdriUH\nfFV9C3jbiPb/CVy7nKIkScvnSlZJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4\nSeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpU4sGfJL7khxP\ncnCo7YIke5M8257Pb+1J8pEkh5M8meSKSRYvSTq1ca7gPwZcd0LbHcC+qtoG7Gv7ANcD29pjBrhn\nZcqUJJ2uRQO+qv4S+N4JzTuB3W17N3DjUPvHa+BLwIYkG1eqWEnS+JY6B39xVR0DaM8XtfZNwHND\n/WZb20mSzCTZn2T/3NzcEsuQJJ3KSv+SNSPaalTHqtpVVdNVNT01NbXCZUiSlhrwz89PvbTn4619\nFtgy1G8zcHTp5UmSlmqpAb8HuKVt3wI8NNT+7nY3zQ7gpfmpHEnSmbV+sQ5JPglcDVyYZBb4PeD3\ngQeS3Ap8B3hX6/4IcANwGPgR8J4J1CxJGsOiAV9VN5/i0LUj+hZw+3KLkiQtnytZJalTBrwkdcqA\nl6ROGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpUwa8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ\n6pQBL0mdMuAlqVMGvCR1atGAT3JfkuNJDg61fTDJXyc50B43DB27M8nhJM8k+dVJFS5JWtg4V/Af\nA64b0X53VW1vj0cAklwG3AT8QnvNf06ybqWKlSSNb9GAr6q/BL435vl2AvdX1ctV9W3gMHDlMuqT\nJC3Rcubg35vkyTaFc35r2wQ8N9RntrWdJMlMkv1J9s/NzS2jDEnSKEsN+HuAnwW2A8eAP2ztGdG3\nRp2gqnZV1XRVTU9NTS2xDEnSqSwp4Kvq+ap6tap+DPwRr03DzAJbhrpuBo4ur0RJ0lIsKeCTbBza\n/Q1g/g6bPcBNSc5NcimwDXh8eSVKkpZi/WIdknwSuBq4MMks8HvA1Um2M5h+OQLcBlBVTyV5AHga\neAW4vapenUzpkqSFLBrwVXXziOZ7F+h/F3DXcoqSJC2fK1klqVMGvCR1yoCXpE4Z8JLUKQNekjpl\nwEtSpxa9TVLq1RO7bjup7RdnProKlUiT4RW8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ6pQBL0md\nMuAlqVMGvCR1yoCXpE4Z8JLUqUUDPsmWJI8mOZTkqSTva+0XJNmb5Nn2fH5rT5KPJDmc5MkkV0x6\nEJKkk41zBf8K8IGqeiuwA7g9yWXAHcC+qtoG7Gv7ANcD29pjBrhnxauWJC1q0YCvqmNV9ZW2/QPg\nELAJ2Ansbt12Aze27Z3Ax2vgS8CGJBtXvHJJ0oJOaw4+yVbgcuAx4OKqOgaDHwLARa3bJuC5oZfN\ntrYTzzWTZH+S/XNzc6dfuSRpQWMHfJI3Ap8G3l9V31+o64i2OqmhaldVTVfV9NTU1LhlSJLGNFbA\nJzmHQbh/oqo+05qfn596ac/HW/sssGXo5ZuBoytTriRpXOPcRRPgXuBQVX146NAe4Ja2fQvw0FD7\nu9vdNDuAl+anciRJZ844X9l3FfBbwNeTHGhtvwv8PvBAkluB7wDvasceAW4ADgM/At6zohVLksay\naMBX1RcZPa8OcO2I/gXcvsy6JEnL5EpWSeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1CkDXpI6ZcDr\np9Yvznz0pLYndt22CpVIk2HAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtS\npwx4SerUOF+6vSXJo0kOJXkqyfta+weT/HWSA+1xw9Br7kxyOMkzSX51kgOQJI02zpduvwJ8oKq+\nkuRNwBNJ9rZjd1fVvxvunOQy4CbgF4C/Dfz3JH+3ql5dycIlSQtb9Aq+qo5V1Vfa9g+AQ8CmBV6y\nE7i/ql6uqm8Dh4ErV6JYSdL4TmsOPslW4HLgsdb03iRPJrkvyfmtbRPw3NDLZln4B4IkaQLGDvgk\nbwQ+Dby/qr4P3AP8LLAdOAb84XzXES+vEeebSbI/yf65ubnTLlyStLCxAj7JOQzC/RNV9RmAqnq+\nql6tqh8Df8Rr0zCzwJahl28Gjp54zqraVVXTVTU9NTW1nDFIkkYY5y6aAPcCh6rqw0PtG4e6/QZw\nsG3vAW5Kcm6SS4FtwOMrV7IkaRzj3EVzFfBbwNeTHGhtvwvcnGQ7g+mXI8BtAFX1VJIHgKcZ3IFz\nu3fQSNKZt2jAV9UXGT2v/sgCr7kLuGsZdUmSlsmVrJLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalT\nBrwkdcqAl6ROGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4NWdJGM/JvF66WxhwEtSp8b5wg+paw8f\nm/nJ9q9t3LWKlUgryyt4/VQbDnepNwa8NMTAV0/G+dLt85I8nuRrSZ5K8qHWfmmSx5I8m+RTSV7X\n2s9t+4fb8a2THYK0cpyiUU/GuYJ/Gbimqt4GbAeuS7ID+APg7qraBrwA3Nr63wq8UFU/B9zd+kln\nJQNdPRvnS7cL+GHbPac9CrgG+GetfTfwQeAeYGfbBngQ+I9J0s4jnVWmb9sFvBbyH1y1SqSVN9Yc\nfJJ1SQ4Ax4G9wDeBF6vqldZlFtjUtjcBzwG04y8Bb17JoiVJixsr4Kvq1araDmwGrgTeOqpbex61\n+uOkq/ckM0n2J9k/Nzc3br2SpDGd1l00VfUi8AVgB7AhyfwUz2bgaNueBbYAtOM/A3xvxLl2VdV0\nVU1PTU0trXpJ0imNcxfNVJINbfv1wDuBQ8CjwG+2brcAD7XtPW2fdvzzzr9L0pk3zkrWjcDuJOsY\n/EB4oKoeTvI0cH+Sfwt8Fbi39b8X+NMkhxlcud80gbolSYsY5y6aJ4HLR7R/i8F8/Int/xt414pU\nJ0laMleySlKnDHhJ6pQBL0md8p8LVne8aUsa8ApekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdcqA\nl6ROGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpUwa8JHVqnC/dPi/J40m+luSpJB9q7R9L8u0k\nB9pje2tPko8kOZzkySRXTHoQkqSTjfPvwb8MXFNVP0xyDvDFJH/Wjv3LqnrwhP7XA9va4+3APe1Z\nknQGLXoFXwM/bLvntMdC36iwE/h4e92XgA1JNi6/VEnS6RhrDj7JuiQHgOPA3qp6rB26q03D3J3k\n3Na2CXhu6OWzrU2SdAaNFfBV9WpVbQc2A1cm+fvAncDPA/8QuAD4ndY9o05xYkOSmST7k+yfm5tb\nUvGSpFM7rbtoqupF4AvAdVV1rE3DvAz8CXBl6zYLbBl62Wbg6Ihz7aqq6aqanpqaWlLxkqRTG+cu\nmqkkG9r264F3At+Yn1dPEuBG4GB7yR7g3e1umh3AS1V1bCLVS5JOaZy7aDYCu5OsY/AD4YGqejjJ\n55NMMZiSOQD8i9b/EeAG4DDwI+A9K1+2JGkxiwZ8VT0JXD6i/ZpT9C/g9uWXJklaDleySlKnDHhJ\n6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdcqAl6RO\nGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0aO+CTrEvy1SQPt/1LkzyW5Nkkn0ryutZ+bts/3I5vnUzp\nkqSFnM4V/PuAQ0P7fwDcXVXbgBeAW1v7rcALVfVzwN2tnyTpDBsr4JNsBv4J8MdtP8A1wIOty27g\nxra9s+3Tjl/b+kuSzqD1Y/b798C/At7U9t8MvFhVr7T9WWBT294EPAdQVa8kean1/+7wCZPMADNt\n9+UkB5c0grPfhZww9k70Oi7od2yOa235O0lmqmrXUk+waMAn+TXgeFU9keTq+eYRXWuMY681DIre\n1d5jf1VNj1XxGtPr2HodF/Q7Nse19iTZT8vJpRjnCv4q4NeT3ACcB/wtBlf0G5Ksb1fxm4Gjrf8s\nsAWYTbIe+Bnge0stUJK0NIvOwVfVnVW1uaq2AjcBn6+qfw48Cvxm63YL8FDb3tP2acc/X1UnXcFL\nkiZrOffB/w7w20kOM5hjv7e13wu8ubX/NnDHGOda8l9B1oBex9bruKDfsTmutWdZY4sX15LUJ1ey\nSlKnVj3gk1yX5Jm28nWc6ZyzSpL7khwfvs0zyQVJ9rZVvnuTnN/ak+QjbaxPJrli9SpfWJItSR5N\ncijJU0ne19rX9NiSnJfk8SRfa+P6UGvvYmV2ryvOkxxJ8vUkB9qdJWv+swiQZEOSB5N8o/2/9o6V\nHNeqBnySdcB/Aq4HLgNuTnLZata0BB8Drjuh7Q5gX1vlu4/Xfg9xPbCtPWaAe85QjUvxCvCBqnor\nsAO4vf3ZrPWxvQxcU1VvA7YD1yXZQT8rs3tecf5LVbV96JbItf5ZBPgPwJ9X1c8Db2PwZ7dy46qq\nVXsA7wA+N7R/J3Dnata0xHFsBQ4O7T8DbGzbG4Fn2vZHgZtH9TvbHwzukvrlnsYG/E3gK8DbGSyU\nWd/af/K5BD4HvKNtr2/9stq1n2I8m1sgXAM8zGBNypofV6vxCHDhCW1r+rPI4Jbzb5/4330lx7Xa\nUzQ/WfXaDK+IXcsurqpjAO35ota+Jsfb/vp+OfAYHYytTWMcAI4De4FvMubKbGB+ZfbZaH7F+Y/b\n/tgrzjm7xwWDxZJ/keSJtgoe1v5n8S3AHPAnbVrtj5O8gRUc12oH/FirXjuy5sab5I3Ap4H3V9X3\nF+o6ou2sHFtVvVpV2xlc8V4JvHVUt/a8JsaVoRXnw80juq6pcQ25qqquYDBNcXuSf7xA37UytvXA\nFcA9VXU58L9Y+Lby0x7Xagf8/KrXecMrYtey55NsBGjPx1v7mhpvknMYhPsnquozrbmLsQFU1YvA\nFxj8jmFDW3kNo1dmc5avzJ5fcX4EuJ/BNM1PVpy3PmtxXABU1dH2fBz4LIMfzGv9szgLzFbVY23/\nQQaBv2LjWu2A/zKwrf2m/3UMVsruWeWaVsLwat4TV/m+u/02fAfw0vxfxc42ScJg0dqhqvrw0KE1\nPbYkU0k2tO3XA+9k8IutNb0yuzpecZ7kDUneNL8N/ApwkDX+Wayq/wE8l+TvtaZrgadZyXGdBb9o\nuAH4KwbzoP96tetZQv2fBI4B/5fBT9hbGcxl7gOebc8XtL5hcNfQN4GvA9OrXf8C4/pHDP769yRw\noD1uWOtjA/4B8NU2roPAv2ntbwEeBw4D/xU4t7Wf1/YPt+NvWe0xjDHGq4GHexlXG8PX2uOp+ZxY\n65/FVut2YH/7PP434PyVHJcrWSWpU6s9RSNJmhADXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpUwa8JHXK\ngJekTv0/s4d/S2dPqMYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f247edf9c18>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "import numpy as np, pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "env = gym.make(\"CartPole-v0\").env\n",
    "env.reset()\n",
    "n_actions = env.action_space.n\n",
    "state_dim = env.observation_space.shape\n",
    "\n",
    "plt.imshow(env.render(\"rgb_array\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the network for REINFORCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For REINFORCE algorithm, we'll need a model that predicts action probabilities given states. Let's define such a model below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#TODO your code here\n",
    "agent = nn.Sequential()\n",
    "agent.add_module('hid', nn.Linear(state_dim[0], 128))\n",
    "agent.add_module('relu', nn.ELU())\n",
    "agent.add_module('logits', nn.Linear(128, n_actions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_proba(states):\n",
    "    \"\"\" \n",
    "    Predict action probabilities given states.\n",
    "    :param states: numpy array of shape [batch, state_shape]\n",
    "    :returns: numpy array of shape [batch, n_actions]\n",
    "    \"\"\"\n",
    "    #TODO your code here\n",
    "    states = Variable(torch.from_numpy(states.astype('float32')))\n",
    "    logits = agent(states)\n",
    "    return F.softmax(logits, -1).data.numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_states = np.array([env.reset() for _ in range(5)])\n",
    "test_probas = predict_proba(test_states)\n",
    "assert isinstance(test_probas, np.ndarray), \"you must return np array and not %s\" % type(test_probas)\n",
    "assert tuple(test_probas.shape) == (test_states.shape[0], n_actions), \"wrong output shape: %s\" % np.shape(test_probas)\n",
    "assert np.allclose(np.sum(test_probas, axis = 1), 1), \"probabilities do not sum to 1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Play the game\n",
    "\n",
    "We can now use our newly built agent to play the game."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_session(t_max=1000):\n",
    "    \"\"\" \n",
    "    play a full session with REINFORCE agent and train at the session end.\n",
    "    returns sequences of states, actions andrewards\n",
    "    \"\"\"\n",
    "    \n",
    "    #arrays to record session\n",
    "    states,actions,rewards = [],[],[]\n",
    "    \n",
    "    s = env.reset()\n",
    "    \n",
    "    for t in range(t_max):\n",
    "        \n",
    "        #action probabilities array aka pi(a|s)\n",
    "        action_probas = predict_proba(np.array([s]))[0] \n",
    "        \n",
    "        #TODO a = <sample action with given probabilities>\n",
    "        a = np.random.choice(n_actions, p=action_probas)\n",
    "        \n",
    "        new_s,r,done,info = env.step(a)\n",
    "        \n",
    "        #record session history to train later\n",
    "        states.append(s)\n",
    "        actions.append(a)\n",
    "        rewards.append(r)\n",
    "        \n",
    "        s = new_s\n",
    "        if done: break\n",
    "            \n",
    "    return states, actions, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it\n",
    "states, actions, rewards = generate_session()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing cumulative rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_cumulative_rewards(rewards, #rewards at each step\n",
    "                           gamma = 0.99 #discount for reward\n",
    "                           ):\n",
    "    \"\"\"\n",
    "    take a list of immediate rewards r(s,a) for the whole session \n",
    "    compute cumulative returns (a.k.a. G(s,a) in Sutton '16)\n",
    "    G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n",
    "    \n",
    "    The simple way to compute cumulative rewards is to iterate from last to first time tick\n",
    "    and compute G_t = r_t + gamma*G_{t+1} recurrently\n",
    "    \n",
    "    You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n",
    "    \"\"\"\n",
    "    \n",
    "    #TODO <your code here>   \n",
    "    #return <array of cumulative rewards>\n",
    "    \n",
    "    \n",
    "    G = [rewards[-1]]\n",
    "    for r in reversed(rewards[:-1]):\n",
    "        G.insert( 0, r + gamma * G[0] )\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "looks good!\n"
     ]
    }
   ],
   "source": [
    "get_cumulative_rewards(rewards)\n",
    "assert len(get_cumulative_rewards(range(100))) == 100\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,0,0,1,0],gamma=0.9),[1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,-2,3,-4,0],gamma=0.5), [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n",
    "assert np.allclose(get_cumulative_rewards([0,0,1,2,3,4,0],gamma=0), [0, 0, 1, 2, 3, 4, 0])\n",
    "print(\"looks good!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loss function and updates\n",
    "\n",
    "We now need to define objective and update over policy gradient.\n",
    "\n",
    "Our objective function is\n",
    "\n",
    "$$ J \\approx  { 1 \\over N } \\sum  _{s_i,a_i} \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "\n",
    "Following the REINFORCE algorithm, we can define our objective as follows: \n",
    "\n",
    "$$ \\hat J \\approx { 1 \\over N } \\sum  _{s_i,a_i} log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n",
    "\n",
    "When you compute gradient of that function over network weights $ \\theta $, it will become exactly the policy gradient.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_one_hot(y, n_dims=None):\n",
    "    \"\"\" Takes an integer vector and converts it to 1-hot matrix. \"\"\"\n",
    "    y_tensor = y.data if isinstance(y, Variable) else y\n",
    "    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)\n",
    "    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1\n",
    "    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)\n",
    "    return Variable(y_one_hot) if isinstance(y, Variable) else y_one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = torch.optim.Adam(list(agent.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_on_session(states, actions, rewards, gamma = 0.99):\n",
    "    \"\"\"\n",
    "    Takes a sequence of states, actions and rewards produced by generate_session.\n",
    "    Updates agent's weights by following the policy gradient above \"\"\"\n",
    "    \n",
    "    # cast everything into a variable\n",
    "    states = Variable(torch.FloatTensor(states))\n",
    "    actions = Variable(torch.IntTensor(actions))\n",
    "    cumulative_returns = np.array(get_cumulative_rewards(rewards, gamma))\n",
    "    cumulative_returns = Variable(torch.FloatTensor(cumulative_returns))\n",
    "    \n",
    "    # predict logits, probas and log-probas using an agent. \n",
    "    #TODO your code here\n",
    "    logits = agent(states)\n",
    "    probas = F.softmax(logits, -1)\n",
    "    logprobas = F.log_softmax(logits, -1)\n",
    "    \n",
    "    assert all(isinstance(v, Variable) for v in [logits, probas, logprobas]), \\\n",
    "        \"please use compute using torch tensors and don't use predict_proba function\"\n",
    "    \n",
    "    # select log-probabilities for chosen actions, log pi(a_i|s_i)\n",
    "    logprobas_for_actions = torch.sum(logprobas * to_one_hot(actions), dim = 1)\n",
    "    \n",
    "    # REINFORCE objective function\n",
    "    J_hat = torch.mean(logprobas_for_actions * cumulative_returns)#TODO#<policy objective as in the formula for J_hat. Please use mean, not sum.>\n",
    "    \n",
    "    #regularize with entropy\n",
    "    #TODOentropy = <compute mean entropy of probas. Don't forget the sign!>\n",
    "    entropy_reg = - torch.mean(torch.sum(probas * logprobas, dim = 1))\n",
    "    \n",
    "    loss = - J_hat - 0.1 * entropy_reg\n",
    "    \n",
    "    # Gradient descent step\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    \n",
    "    return np.sum(rewards)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The actual training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean reward:21.380\n",
      "mean reward:23.330\n",
      "mean reward:50.980\n",
      "mean reward:112.400\n",
      "mean reward:144.060\n",
      "mean reward:95.870\n",
      "mean reward:183.300\n",
      "mean reward:127.050\n",
      "mean reward:120.070\n",
      "mean reward:123.760\n",
      "mean reward:142.540\n",
      "mean reward:587.320\n",
      "You Win!\n"
     ]
    }
   ],
   "source": [
    "for i in range(100):\n",
    "    \n",
    "    rewards = [train_on_session(*generate_session()) for _ in range(100)] #generate new sessions\n",
    "    \n",
    "    print (\"mean reward:%.3f\"%(np.mean(rewards)))\n",
    "\n",
    "    if np.mean(rewards) > 500:\n",
    "        print (\"You Win!\") # but you can train even further\n",
    "        break\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#record sessions\n",
    "import gym.wrappers\n",
    "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),directory=\"videos\",force=True)\n",
    "sessions = [generate_session() for _ in range(100)]\n",
    "env.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<video width=\"640\" height=\"480\" controls>\n",
       "  <source src=\"./videos/openaigym.video.0.7822.video000027.mp4\" type=\"video/mp4\">\n",
       "</video>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#show video\n",
    "from IPython.display import HTML\n",
    "import os\n",
    "\n",
    "video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./videos/\")))\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(\"./videos/\"+video_names[-1])) #this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
